{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
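"If you missed Lab 12 and need to install the graphviz package, run the cell below. (This installs the Python interface only; drawing graphs also requires the Graphviz system software.)"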
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install --user graphviz"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lab 13 - Decision Trees for regression\n",
"\n",
"For this lab, we will return to the insurance data from Labs 7 and 8. Recall we are trying to predict the insurance cost, a quantitative value. \n",
"\n",
"If you don't have the dataset, download it from GitHub: [https://github.com/stedy/Machine-Learning-with-R-datasets/blob/master/insurance.csv](https://github.com/stedy/Machine-Learning-with-R-datasets/blob/master/insurance.csv)\n",
"\n",
"In this data, each row represents an insurance policy and the 7 columns contain the following information about it:\n",
"- age: age of policy holder\n",
"- sex: sex of policy holder\n",
"- bmi: body mass index (BMI) of policy holder. BMI is a (sometimes unreliable) measurement of body fat in adults\n",
"- children: number of children (dependents) on the policy\n",
"- smoker: whether the policy holder is a smoker\n",
"- region: region of the country the policy holder lives in\n",
"- charges: price for insurance policy"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn import tree\n",
"import graphviz\n",
"from graphviz import Source\n",
" \n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from sklearn.tree import export_graphviz\n",
"import sklearn.metrics as met\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Read the data into a dataframe and display it to make sure it was read in correctly:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" age sex bmi children smoker region charges\n",
"0 19 female 27.900 0 yes southwest 16884.92400\n",
"1 18 male 33.770 1 no southeast 1725.55230\n",
"2 28 male 33.000 3 no southeast 4449.46200\n",
"3 33 male 22.705 0 no northwest 21984.47061\n",
"4 32 male 28.880 0 no northwest 3866.85520"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"insurance = pd.read_csv(\"insurance.csv\")\n",
"insurance.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Decision trees in scikit-learn require numeric data. How can we convert the categorical columns into numeric data?\n",
"Hint: see Lab 8"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"insurance = pd.get_dummies(insurance, columns = [\"sex\", \"smoker\", \"region\"], drop_first = True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" age bmi children charges sex_male smoker_yes region_northwest \\\n",
"0 19 27.900 0 16884.92400 0 1 0 \n",
"1 18 33.770 1 1725.55230 1 0 0 \n",
"2 28 33.000 3 4449.46200 1 0 0 \n",
"3 33 22.705 0 21984.47061 1 0 1 \n",
"4 32 28.880 0 3866.85520 1 0 1 \n",
"\n",
" region_southeast region_southwest \n",
"0 0 1 \n",
"1 1 0 \n",
"2 1 0 \n",
"3 0 0 \n",
"4 0 0 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"insurance.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fitting a decision tree with scikit-learn\n",
"\n",
"We can get just the independent variables (x's) using the following:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" age bmi children sex_male smoker_yes region_northwest \\\n",
"0 19 27.900 0 0 1 0 \n",
"1 18 33.770 1 1 0 0 \n",
"2 28 33.000 3 1 0 0 \n",
"3 33 22.705 0 1 0 1 \n",
"4 32 28.880 0 1 0 1 \n",
"\n",
" region_southeast region_southwest \n",
"0 0 1 \n",
"1 1 0 \n",
"2 1 0 \n",
"3 0 0 \n",
"4 0 0 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = insurance.drop(\"charges\", axis = 1)  # every column except the response, charges\n",
"X.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we create the decision tree regressor (an object) and fit it to our data, using `X` as the independent variables and `charges` as the response:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"reg = tree.DecisionTreeRegressor(max_depth = 5)\n",
"reg = reg.fit(X, insurance[\"charges\"])"
]
},
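{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even without drawing the tree, we can check how many training rows end up in each leaf. The sketch below assumes `reg` and `X` from the cells above. `reg.apply(X)` returns the id of the leaf each row lands in, and `np.bincount` counts the rows per id; internal (non-leaf) nodes never receive rows, so we drop the zero counts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# id of the leaf that each row of X falls into\n",
"leaf_ids = reg.apply(X)\n",
"\n",
"# rows per node id; internal nodes get a count of 0, so keep only the leaves\n",
"counts = np.bincount(leaf_ids)\n",
"leaf_sizes = counts[counts > 0]\n",
"\n",
"print(\"Number of leaves:\", len(leaf_sizes))\n",
"print(\"Smallest leaf:\", leaf_sizes.min(), \"samples\")\n",
"print(\"Largest leaf:\", leaf_sizes.max(), \"samples\")"
]
},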
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you are running Jupyter on your own computer (with scikit-learn 0.21 or newer), you may be able to display the decision tree with:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
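"plt.figure(figsize = (20, 10))  # a depth-5 tree needs a large canvas to be readable\n",
"tree.plot_tree(reg, feature_names = list(X.columns), filled = True)"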
]
},
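{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the picture is too crowded to read, scikit-learn 0.21 and newer also provide `export_text`, which prints the same splits as indented text. This is a sketch that assumes a recent scikit-learn; it is likely unavailable on the server, which runs Python 3.4 and therefore an older scikit-learn."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.tree import export_text\n",
"\n",
"# one line per split or leaf; feature_names replaces the generic default labels\n",
"print(export_text(reg, feature_names = list(X.columns)))"
]
},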
{
"cell_type": "markdown",
"metadata": {},
"source": [
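"If you are using the JupyterHub server, run the following code instead. It will raise a `PermissionError` (the server does not allow the Graphviz `dot` program to run), but it still writes the tree description to a file:"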
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"ename": "PermissionError",
"evalue": "[Errno 13] Permission denied",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mPermissionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdot_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_graphviz\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_file\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraphviz\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSource\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdot_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mgraph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"insurance.dot\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/files.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, filename, directory, view, cleanup, format, renderer, formatter)\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0mformat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_format\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 188\u001b[0;31m \u001b[0mrendered\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilepath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcleanup\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/backend.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(engine, format, filepath, renderer, formatter, quiet)\u001b[0m\n\u001b[1;32m 181\u001b[0m \"\"\"\n\u001b[1;32m 182\u001b[0m \u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrendered\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcommand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilepath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 183\u001b[0;31m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcapture_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcheck\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquiet\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mquiet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 184\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrendered\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/backend.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(cmd, input, capture_output, check, quiet, **kwargs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m \u001b[0mproc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msubprocess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstartupinfo\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_startupinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mOSError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merrno\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0merrno\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mENOENT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.4/subprocess.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, args, bufsize, executable, stdin, stdout, stderr, preexec_fn, close_fds, shell, cwd, env, universal_newlines, startupinfo, creationflags, restore_signals, start_new_session, pass_fds)\u001b[0m\n\u001b[1;32m 854\u001b[0m \u001b[0mc2pread\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc2pwrite\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 855\u001b[0m \u001b[0merrread\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merrwrite\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 856\u001b[0;31m restore_signals, start_new_session)\n\u001b[0m\u001b[1;32m 857\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 858\u001b[0m \u001b[0;31m# Cleanup if the child failed starting.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.4/subprocess.py\u001b[0m in \u001b[0;36m_execute_child\u001b[0;34m(self, args, executable, preexec_fn, close_fds, pass_fds, cwd, env, startupinfo, creationflags, shell, p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite, restore_signals, start_new_session)\u001b[0m\n\u001b[1;32m 1462\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1463\u001b[0m \u001b[0merr_msg\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m': '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mrepr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_executable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1464\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mchild_exception_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merrno_num\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merr_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1465\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mchild_exception_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merr_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1466\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mPermissionError\u001b[0m: [Errno 13] Permission denied"
]
}
],
"source": [
"dot_data = tree.export_graphviz(reg, out_file=None) \n",
"graph = graphviz.Source(dot_data) \n",
"graph.render(\"insurance.dot\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, despite the error, there should now be a file called insurance.dot in your directory. To view the fitted decision tree, open insurance.dot in Jupyter and copy its text. Paste the text into the text box at [http://www.webgraphviz.com](http://www.webgraphviz.com) and click the \"Generate graph!\" button at the bottom.\n",
"\n",
"In this file, the column names have been replaced by `X[0], X[1], ..., X[7]`. Run the following code to replace `X[0], X[1], ..., X[7]` in insurance.dot with the column names and save the result as insurance_fixed.dot."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# replace X[0], X[1], ..., X[7] with the column names of X, in order\n",
"with open(\"insurance.dot\", \"r\") as fin:\n",
"    with open(\"insurance_fixed.dot\", \"w\") as fout:\n",
"        for line in fin:\n",
"            for i, name in enumerate(X.columns):\n",
"                line = line.replace(\"X[\" + str(i) + \"]\", name)\n",
"            fout.write(line)"
]
},
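{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to the find-and-replace step, `export_graphviz` accepts a `feature_names` argument and writes the .dot file directly when `out_file` is a file name. Because this skips `graph.render`, it does not need to run the Graphviz program, so it should work on the server too. This is a sketch assuming `reg` and `X` from above; the file name insurance_named.dot is just an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# write the tree straight to a .dot file, labelling the splits with the real column names\n",
"tree.export_graphviz(reg,\n",
"                     out_file = \"insurance_named.dot\",\n",
"                     feature_names = list(X.columns),\n",
"                     filled = True,\n",
"                     rounded = True)"
]
},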
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copy the contents of insurance_fixed.dot into the text box at [http://www.webgraphviz.com](http://www.webgraphviz.com) to display the decision tree with the column names. How does it compare to the decision tree you made?\n",
"\n",
"What happens if you change the `max_depth` parameter in `DecisionTreeRegressor`? Try a few values, refit the tree, and redraw it.\n",
"\n",
"Look at the leaves of your new tree. What is the smallest number of samples in a leaf?\n",
"\n",
"A few of the leaves contain only 1 sample. How well do you think this tree would work on other insurance data?\n",
"\n",
"Leaves with a single sample are a sign of overfitting: the tree has memorized individual policies instead of general patterns. To reduce overfitting we can make `max_depth` smaller, but if it is too small the model will underfit and miss real structure in the data.\n",
"\n",
"### Testing and training data\n",
"\n",
"To figure out what `max_depth` should be, let's split our data into training and testing data. "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, insurance[\"charges\"], test_size=0.2)"
]
},
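{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the split, look at the shapes: with `test_size=0.2`, about 20% of the 1338 rows are held out for testing. The split is random, so each run (and each student) gets different test rows and slightly different error values below. Passing `random_state` (for example `random_state=0`) to `train_test_split` would make the split reproducible."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# (rows, columns) of each piece of the split\n",
"print(\"Training:\", X_train.shape, y_train.shape)\n",
"print(\"Testing: \", X_test.shape, y_test.shape)"
]
},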
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a decision tree with `max_depth = 3` from the training data:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"reg3 = tree.DecisionTreeRegressor(max_depth = 3)\n",
"reg3 = reg3.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make predictions for the test data:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([13785.78482646, 6161.11114528, 18723.54895898, 6161.11114528,\n",
" 6161.11114528, 13785.78482646, 2795.03725199, 2795.03725199,\n",
" 13785.78482646, 6161.11114528, 2795.03725199, 2795.03725199,\n",
" 6161.11114528, 13785.78482646, 6161.11114528, 18723.54895898,\n",
" 10351.84377925, 13785.78482646, 13785.78482646, 38722.41603063,\n",
" 38722.41603063, 13785.78482646, 2795.03725199, 6161.11114528,\n",
" 6161.11114528, 6161.11114528, 13785.78482646, 13785.78482646,\n",
" 10351.84377925, 6161.11114528, 18723.54895898, 45606.72260404,\n",
" 18723.54895898, 2795.03725199, 10351.84377925, 6161.11114528,\n",
" 6161.11114528, 13785.78482646, 45606.72260404, 6161.11114528,\n",
" 10351.84377925, 45606.72260404, 18723.54895898, 6161.11114528,\n",
" 6161.11114528, 6161.11114528, 2795.03725199, 10351.84377925,\n",
" 18723.54895898, 38722.41603063, 45606.72260404, 13785.78482646,\n",
" 13785.78482646, 13785.78482646, 6161.11114528, 2795.03725199,\n",
" 2795.03725199, 6161.11114528, 24603.2390669 , 13785.78482646,\n",
" 13785.78482646, 13785.78482646, 6161.11114528, 24603.2390669 ,\n",
" 2795.03725199, 2795.03725199, 13785.78482646, 6161.11114528,\n",
" 13785.78482646, 6161.11114528, 10351.84377925, 6161.11114528,\n",
" 10351.84377925, 6161.11114528, 6161.11114528, 2795.03725199,\n",
" 2795.03725199, 6161.11114528, 18723.54895898, 13785.78482646,\n",
" 6161.11114528, 6161.11114528, 24603.2390669 , 6161.11114528,\n",
" 24603.2390669 , 24603.2390669 , 6161.11114528, 13785.78482646,\n",
" 13785.78482646, 13785.78482646, 18723.54895898, 6161.11114528,\n",
" 6161.11114528, 13785.78482646, 38722.41603063, 13785.78482646,\n",
" 13785.78482646, 10351.84377925, 24603.2390669 , 6161.11114528,\n",
" 6161.11114528, 13785.78482646, 6161.11114528, 18723.54895898,\n",
" 10351.84377925, 13785.78482646, 18723.54895898, 2795.03725199,\n",
" 13785.78482646, 6161.11114528, 6161.11114528, 6161.11114528,\n",
" 24603.2390669 , 6161.11114528, 13785.78482646, 13785.78482646,\n",
" 6161.11114528, 6161.11114528, 45606.72260404, 13785.78482646,\n",
" 6161.11114528, 38722.41603063, 6161.11114528, 6161.11114528,\n",
" 6161.11114528, 2795.03725199, 6161.11114528, 6161.11114528,\n",
" 2795.03725199, 13785.78482646, 6161.11114528, 6161.11114528,\n",
" 45606.72260404, 13785.78482646, 38722.41603063, 13785.78482646,\n",
" 10351.84377925, 13785.78482646, 13785.78482646, 6161.11114528,\n",
" 13785.78482646, 2795.03725199, 2795.03725199, 6161.11114528,\n",
" 10351.84377925, 18723.54895898, 10351.84377925, 2795.03725199,\n",
" 10351.84377925, 18723.54895898, 10351.84377925, 2795.03725199,\n",
" 2795.03725199, 6161.11114528, 6161.11114528, 10351.84377925,\n",
" 10351.84377925, 6161.11114528, 6161.11114528, 10351.84377925,\n",
" 10351.84377925, 18723.54895898, 13785.78482646, 38722.41603063,\n",
" 6161.11114528, 24603.2390669 , 10351.84377925, 13785.78482646,\n",
" 6161.11114528, 6161.11114528, 6161.11114528, 10351.84377925,\n",
" 6161.11114528, 13785.78482646, 6161.11114528, 13785.78482646,\n",
" 13785.78482646, 45606.72260404, 38722.41603063, 6161.11114528,\n",
" 2795.03725199, 24603.2390669 , 10351.84377925, 38722.41603063,\n",
" 10351.84377925, 38722.41603063, 6161.11114528, 10351.84377925,\n",
" 6161.11114528, 10351.84377925, 6161.11114528, 13785.78482646,\n",
" 38722.41603063, 38722.41603063, 2795.03725199, 13785.78482646,\n",
" 10351.84377925, 13785.78482646, 10351.84377925, 13785.78482646,\n",
" 6161.11114528, 2795.03725199, 10351.84377925, 13785.78482646,\n",
" 10351.84377925, 2795.03725199, 10351.84377925, 6161.11114528,\n",
" 2795.03725199, 45606.72260404, 13785.78482646, 6161.11114528,\n",
" 13785.78482646, 13785.78482646, 10351.84377925, 6161.11114528,\n",
" 24603.2390669 , 2795.03725199, 13785.78482646, 13785.78482646,\n",
" 13785.78482646, 6161.11114528, 6161.11114528, 6161.11114528,\n",
" 13785.78482646, 6161.11114528, 10351.84377925, 6161.11114528,\n",
" 2795.03725199, 13785.78482646, 2795.03725199, 10351.84377925,\n",
" 38722.41603063, 13785.78482646, 13785.78482646, 45606.72260404,\n",
" 13785.78482646, 24603.2390669 , 6161.11114528, 6161.11114528,\n",
" 6161.11114528, 6161.11114528, 24603.2390669 , 10351.84377925,\n",
" 6161.11114528, 6161.11114528, 6161.11114528, 10351.84377925,\n",
" 24603.2390669 , 38722.41603063, 10351.84377925, 2795.03725199,\n",
" 6161.11114528, 2795.03725199, 45606.72260404, 2795.03725199,\n",
" 24603.2390669 , 2795.03725199, 2795.03725199, 6161.11114528,\n",
" 10351.84377925, 13785.78482646, 10351.84377925, 10351.84377925,\n",
" 13785.78482646, 18723.54895898, 10351.84377925, 2795.03725199])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions_3 = reg3.predict(X_test)\n",
"predictions_3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the mean squared error for these predictions:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"27792662.098281976"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"((y_test - predictions_3)**2).mean()"
]
},
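{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn computes the same quantity with `mean_squared_error` (available here as `met.mean_squared_error`, since `sklearn.metrics` was imported as `met`). Because the MSE is in squared dollars, its square root, the root mean squared error (RMSE), is often easier to interpret: it is on the same scale as the charges themselves."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mse_3 = met.mean_squared_error(y_test, predictions_3)  # same value as the cell above\n",
"rmse_3 = np.sqrt(mse_3)  # back in dollars, the units of charges\n",
"print(\"MSE: \", mse_3)\n",
"print(\"RMSE:\", rmse_3)"
]
},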
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What is the mean squared error if you use `max_depth = 4`?"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"29737038.361192185"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg4 = tree.DecisionTreeRegressor(max_depth = 4)\n",
"reg4 = reg4.fit(X_train, y_train)\n",
"predictions_4 = reg4.predict(X_test)\n",
"((y_test - predictions_4)**2).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What is the mean squared error if you use `max_depth = 5`?"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32480100.453982316"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg5 = tree.DecisionTreeRegressor(max_depth = 5)\n",
"reg5 = reg5.fit(X_train, y_train)\n",
"predictions_5 = reg5.predict(X_test)\n",
"((y_test - predictions_5)**2).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What about if you use `max_depth = 2`?"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32211767.651123475"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg2 = tree.DecisionTreeRegressor(max_depth = 2)\n",
"reg2 = reg2.fit(X_train, y_train)\n",
"predictions_2 = reg2.predict(X_test)\n",
"((y_test - predictions_2)**2).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Which `max_depth` parameter should you use? What is the corresponding decision tree?\n",
"\n",
"You can also use a loop to quickly check the different parameter values for `max_depth`. "
]
},
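{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of such a loop, assuming the `X_train`, `X_test`, `y_train` and `y_test` split from above. It refits a tree for each `max_depth` from 1 to 10, records the test mean squared error, and reports the depth with the smallest error. Keep in mind that the best depth can change with a different random split."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mse_by_depth = {}\n",
"\n",
"for depth in range(1, 11):\n",
"    # fit a tree of this depth on the training data only\n",
"    model = tree.DecisionTreeRegressor(max_depth = depth)\n",
"    model = model.fit(X_train, y_train)\n",
"\n",
"    # evaluate on the held-out test data\n",
"    predictions = model.predict(X_test)\n",
"    mse_by_depth[depth] = ((y_test - predictions)**2).mean()\n",
"    print(depth, mse_by_depth[depth])\n",
"\n",
"# depth whose test MSE is smallest\n",
"best_depth = min(mse_by_depth, key = mse_by_depth.get)\n",
"print(\"Best max_depth:\", best_depth)"
]
},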
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dot_data = tree.export_graphviz(reg3, out_file=None)  # reg3 is the max_depth = 3 tree fit on the training data\n",
"graph = graphviz.Source(dot_data) \n",
"graph.render(\"insurance_depth3.dot\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# replace X[0], X[1], ..., X[7] with the column names of X, in order\n",
"with open(\"insurance_depth3.dot\", \"r\") as fin:\n",
"    with open(\"insurance_depth3_fixed.dot\", \"w\") as fout:\n",
"        for line in fin:\n",
"            for i, name in enumerate(X.columns):\n",
"                line = line.replace(\"X[\" + str(i) + \"]\", name)\n",
"            fout.write(line)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we can compare the mean squared error of the decision tree regressor to the mean squared error of the linear regression from Lab 8, which also used a 0.2 training/testing split. For my training/testing data it was 41142821.67547247; your numbers will differ because `train_test_split` picks the test rows at random.\n",
"\n",
"Which model is better?"
]
},
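{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the Lab 8 number came from a different random split, a fairer comparison fits a linear regression on the same `X_train`/`X_test` split used here. This is a sketch assuming scikit-learn's `LinearRegression` with default settings; Lab 8 may have set up its model differently."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# ordinary least squares on the same training data as the trees\n",
"lin = LinearRegression()\n",
"lin = lin.fit(X_train, y_train)\n",
"\n",
"# test-set mean squared error, directly comparable to the tree MSEs above\n",
"predictions_lin = lin.predict(X_test)\n",
"((y_test - predictions_lin)**2).mean()"
]
},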
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Return to the decision tree classifier from last lab. Which `max_depth` is best?"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}